• DOMAIN: Telecom
• CONTEXT: A telecom company wants to use its historical customer data to predict customer behaviour and retain customers. Analysing all relevant customer data makes it possible to develop focused customer retention programs.
• DATA DESCRIPTION: Each row represents a customer, and each column contains a customer attribute described in the column metadata. The data set includes information about:
  • Customers who left within the last month: the column is called Churn
  • Services that each customer has signed up for: phone, multiple lines, internet, online security, online backup, device protection, tech support, and streaming TV and movies
  • Customer account information: how long they've been a customer, contract, payment method, paperless billing, monthly charges, and total charges
  • Demographic info about customers: gender, age range, and whether they have partners and dependents
• PROJECT OBJECTIVE: Build a model that helps identify customers with a higher probability of churning. This helps the company understand the pain points and patterns of customer churn and sharpens the focus on customer retention strategy.
• Steps to the project: [ Total score: 60 points ]
# Import libraries
import pandas as pd
import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
#from sklearn.feature_extraction.text import CountVectorizer #DT does not take strings as input for the model fit step....
from IPython.display import Image
#import pydotplus as pydot
from sklearn import tree
from os import system
data = pd.read_csv('TelcomCustomer-Churn.csv')
data.head()
| | customerID | gender | SeniorCitizen | Partner | Dependents | tenure | PhoneService | MultipleLines | InternetService | OnlineSecurity | ... | DeviceProtection | TechSupport | StreamingTV | StreamingMovies | Contract | PaperlessBilling | PaymentMethod | MonthlyCharges | TotalCharges | Churn |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7590-VHVEG | Female | 0 | Yes | No | 1 | No | No phone service | DSL | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 29.85 | 29.85 | No |
| 1 | 5575-GNVDE | Male | 0 | No | No | 34 | Yes | No | DSL | Yes | ... | Yes | No | No | No | One year | No | Mailed check | 56.95 | 1889.5 | No |
| 2 | 3668-QPYBK | Male | 0 | No | No | 2 | Yes | No | DSL | Yes | ... | No | No | No | No | Month-to-month | Yes | Mailed check | 53.85 | 108.15 | Yes |
| 3 | 7795-CFOCW | Male | 0 | No | No | 45 | No | No phone service | DSL | Yes | ... | Yes | Yes | No | No | One year | No | Bank transfer (automatic) | 42.30 | 1840.75 | No |
| 4 | 9237-HQITU | Female | 0 | No | No | 2 | Yes | No | Fiber optic | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 70.70 | 151.65 | Yes |
5 rows × 21 columns
data.shape
(7043, 21)
data.size
147903
# getting total number of rows and column in the dataframe
print(f" Shape of the dataframe = {data.shape}")
totalrows=data.shape[0]
print(f" Total number of rows in the dataset = {totalrows}")
Shape of the dataframe = (7043, 21)
Total number of rows in the dataset = 7043
print ("Rows : " ,data.shape[0])
print ("Columns : " ,data.shape[1])
print ("\nFeatures : \n" ,data.columns.tolist())
print ("\nMissing values : ", data.isnull().sum().values.sum())
print ("\nUnique values : \n",data.nunique())
Rows : 7043
Columns : 21
Features :
['customerID', 'gender', 'SeniorCitizen', 'Partner', 'Dependents', 'tenure', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod', 'MonthlyCharges', 'TotalCharges', 'Churn']
Missing values : 0
Unique values :
customerID 7043
gender 2
SeniorCitizen 2
Partner 2
Dependents 2
tenure 73
PhoneService 2
MultipleLines 3
InternetService 3
OnlineSecurity 3
OnlineBackup 3
DeviceProtection 3
TechSupport 3
StreamingTV 3
StreamingMovies 3
Contract 3
PaperlessBilling 2
PaymentMethod 4
MonthlyCharges 1585
TotalCharges 6531
Churn 2
dtype: int64
# Check for null values in each column
data_nullcols = data.isnull().sum()
data_nullcols
customerID 0
gender 0
SeniorCitizen 0
Partner 0
Dependents 0
tenure 0
PhoneService 0
MultipleLines 0
InternetService 0
OnlineSecurity 0
OnlineBackup 0
DeviceProtection 0
TechSupport 0
StreamingTV 0
StreamingMovies 0
Contract 0
PaperlessBilling 0
PaymentMethod 0
MonthlyCharges 0
TotalCharges 0
Churn 0
dtype: int64
# Check columns types
data.dtypes
customerID object
gender object
SeniorCitizen int64
Partner object
Dependents object
tenure int64
PhoneService object
MultipleLines object
InternetService object
OnlineSecurity object
OnlineBackup object
DeviceProtection object
TechSupport object
StreamingTV object
StreamingMovies object
Contract object
PaperlessBilling object
PaymentMethod object
MonthlyCharges float64
TotalCharges object
Churn object
dtype: object
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7043 entries, 0 to 7042
Data columns (total 21 columns):
# Column Non-Null Count Dtype
0 customerID 7043 non-null object
1 gender 7043 non-null object
2 SeniorCitizen 7043 non-null int64
3 Partner 7043 non-null object
4 Dependents 7043 non-null object
5 tenure 7043 non-null int64
6 PhoneService 7043 non-null object
7 MultipleLines 7043 non-null object
8 InternetService 7043 non-null object
9 OnlineSecurity 7043 non-null object
10 OnlineBackup 7043 non-null object
11 DeviceProtection 7043 non-null object
12 TechSupport 7043 non-null object
13 StreamingTV 7043 non-null object
14 StreamingMovies 7043 non-null object
15 Contract 7043 non-null object
16 PaperlessBilling 7043 non-null object
17 PaymentMethod 7043 non-null object
18 MonthlyCharges 7043 non-null float64
19 TotalCharges 7043 non-null object
20 Churn 7043 non-null object
dtypes: float64(1), int64(2), object(18)
memory usage: 1.1+ MB
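data.mean(numeric_only=True) # restrict to numeric columns; recent pandas versions raise an error on the object columns otherwise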
SeniorCitizen 0.162147
tenure 32.371149
MonthlyCharges 64.761692
dtype: float64
data.describe().transpose()
| | count | mean | std | min | 25% | 50% | 75% | max |
|---|---|---|---|---|---|---|---|---|
| SeniorCitizen | 7043.0 | 0.162147 | 0.368612 | 0.00 | 0.0 | 0.00 | 0.00 | 1.00 |
| tenure | 7043.0 | 32.371149 | 24.559481 | 0.00 | 9.0 | 29.00 | 55.00 | 72.00 |
| MonthlyCharges | 7043.0 | 64.761692 | 30.090047 | 18.25 | 35.5 | 70.35 | 89.85 | 118.75 |
data.dtypes.value_counts()
object 18
int64 2
float64 1
dtype: int64
# Distribution of each numeric column (histplot replaces the deprecated distplot)
plt.figure(figsize=(20,6))
# subplot 1
plt.subplot(1, 3, 1)
plt.title('SeniorCitizen')
sns.histplot(x=data['SeniorCitizen'], kde=True, stat='density', color='green')
# subplot 2
plt.subplot(1, 3, 2)
plt.title('tenure')
sns.histplot(x=data['tenure'], kde=True, stat='density', color='blue')
# subplot 3
plt.subplot(1, 3, 3)
plt.title('MonthlyCharges')
sns.histplot(x=data['MonthlyCharges'], kde=True, stat='density', color='red')
# Box plots of the same three columns; passing y= draws them vertically
plt.figure(figsize=(20,6))
# subplot 1
plt.subplot(1, 3, 1)
plt.title('SeniorCitizen')
sns.boxplot(y=data['SeniorCitizen'], color='green')
# subplot 2
plt.subplot(1, 3, 2)
plt.title('tenure')
sns.boxplot(y=data['tenure'], color='blue')
# subplot 3 (the original repeated tenure here; MonthlyCharges matches the histograms above)
plt.subplot(1, 3, 3)
plt.title('MonthlyCharges')
sns.boxplot(y=data['MonthlyCharges'], color='red')
plt.show()
sns.countplot(x=data['MonthlyCharges']);
sns.boxplot(x=data['SeniorCitizen'], y=data['tenure']);
# Correlation between the numeric columns only (the object columns cannot be correlated)
corr = data.select_dtypes(include='number').corr()
corr.style
| | SeniorCitizen | tenure | MonthlyCharges |
|---|---|---|---|
| SeniorCitizen | 1.000000 | 0.016567 | 0.220173 |
| tenure | 0.016567 | 1.000000 | 0.247900 |
| MonthlyCharges | 0.220173 | 0.247900 | 1.000000 |
corr.dtypes
SeniorCitizen float64
tenure float64
MonthlyCharges float64
dtype: object
plt.figure(figsize=(20,12))
plt.subplot(1,2,1)
sns.boxplot(x = 'SeniorCitizen', y = 'tenure', data = data)
plt.subplot(1,2,2)
sns.boxplot(x = 'MonthlyCharges', y = 'tenure', data = data)
<AxesSubplot:xlabel='MonthlyCharges', ylabel='tenure'>
sns.histplot(x = 'SeniorCitizen', y = 'tenure', data = data)
<AxesSubplot:xlabel='SeniorCitizen', ylabel='tenure'>
sns.pairplot(data)
<seaborn.axisgrid.PairGrid at 0x1b7bc8c9910>
sns.scatterplot(x=data['SeniorCitizen'], y=data['tenure']) # Scatter plot of tenure against SeniorCitizen
<AxesSubplot:xlabel='SeniorCitizen', ylabel='tenure'>
sns.jointplot(x=data['tenure'], y=data['MonthlyCharges'], color='green');
sns.heatmap(data.select_dtypes(include='number').corr(), annot=True) # plot the correlation coefficients of the numeric columns as a heatmap
<AxesSubplot:>
Let's convert the columns with an 'object' datatype into categorical variables.
for feature in data.columns: # Loop through all columns in the dataframe
    if data[feature].dtype == 'object': # Only apply to columns holding strings
        data[feature] = pd.Categorical(data[feature]) # Store as category dtype; the displayed labels stay the same
data.head(10)
| | customerID | gender | SeniorCitizen | Partner | Dependents | tenure | PhoneService | MultipleLines | InternetService | OnlineSecurity | ... | DeviceProtection | TechSupport | StreamingTV | StreamingMovies | Contract | PaperlessBilling | PaymentMethod | MonthlyCharges | TotalCharges | Churn |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7590-VHVEG | Female | 0 | Yes | No | 1 | No | No phone service | DSL | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 29.85 | 29.85 | No |
| 1 | 5575-GNVDE | Male | 0 | No | No | 34 | Yes | No | DSL | Yes | ... | Yes | No | No | No | One year | No | Mailed check | 56.95 | 1889.5 | No |
| 2 | 3668-QPYBK | Male | 0 | No | No | 2 | Yes | No | DSL | Yes | ... | No | No | No | No | Month-to-month | Yes | Mailed check | 53.85 | 108.15 | Yes |
| 3 | 7795-CFOCW | Male | 0 | No | No | 45 | No | No phone service | DSL | Yes | ... | Yes | Yes | No | No | One year | No | Bank transfer (automatic) | 42.30 | 1840.75 | No |
| 4 | 9237-HQITU | Female | 0 | No | No | 2 | Yes | No | Fiber optic | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 70.70 | 151.65 | Yes |
| 5 | 9305-CDSKC | Female | 0 | No | No | 8 | Yes | Yes | Fiber optic | No | ... | Yes | No | Yes | Yes | Month-to-month | Yes | Electronic check | 99.65 | 820.5 | Yes |
| 6 | 1452-KIOVK | Male | 0 | No | Yes | 22 | Yes | Yes | Fiber optic | No | ... | No | No | Yes | No | Month-to-month | Yes | Credit card (automatic) | 89.10 | 1949.4 | No |
| 7 | 6713-OKOMC | Female | 0 | No | No | 10 | No | No phone service | DSL | Yes | ... | No | No | No | No | Month-to-month | No | Mailed check | 29.75 | 301.9 | No |
| 8 | 7892-POOKP | Female | 0 | Yes | No | 28 | Yes | Yes | Fiber optic | No | ... | Yes | Yes | Yes | Yes | Month-to-month | Yes | Electronic check | 104.80 | 3046.05 | Yes |
| 9 | 6388-TABGU | Male | 0 | No | Yes | 62 | Yes | No | DSL | Yes | ... | No | No | No | No | One year | No | Bank transfer (automatic) | 56.15 | 3487.95 | No |
10 rows × 21 columns
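The table above looks unchanged because `pd.Categorical` only changes the dtype, not the displayed labels. A minimal sketch on a toy column (the `contract` series below is made up for illustration, not taken from the dataset) shows where the integer codes actually live:

```python
import pandas as pd

# Toy column standing in for one of the object-typed columns (hypothetical values).
contract = pd.Series(["Month-to-month", "One year", "Two year", "Month-to-month"])

as_cat = pd.Series(pd.Categorical(contract))

# The visible values are the same strings; only the dtype differs.
print(as_cat.dtype)
print(as_cat.tolist())

# The integer codes exist, but have to be requested explicitly.
# Categories are sorted alphabetically by default, which fixes the code order.
codes = as_cat.cat.codes
print(codes.tolist())
```

This is why the notebook still needs an explicit mapping or one-hot encoding step before fitting a tree.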
Get value counts for every column
print(data.gender.value_counts())
print(data.Partner.value_counts())
print(data.Dependents.value_counts())
print(data.PhoneService.value_counts())
print(data.MultipleLines.value_counts())
print(data.InternetService.value_counts())
print(data.OnlineSecurity.value_counts())
print(data.OnlineBackup.value_counts())
print(data.DeviceProtection.value_counts())
print(data.TechSupport.value_counts())
print(data.StreamingTV.value_counts())
print(data.StreamingMovies.value_counts())
print(data.Contract.value_counts())
print(data.PaperlessBilling.value_counts())
print(data.PaymentMethod.value_counts())
print(data.MonthlyCharges.value_counts())
print(data.TotalCharges.value_counts())
print(data.Churn.value_counts())
Male 3555
Female 3488
Name: gender, dtype: int64
No 3641
Yes 3402
Name: Partner, dtype: int64
No 4933
Yes 2110
Name: Dependents, dtype: int64
Yes 6361
No 682
Name: PhoneService, dtype: int64
No 3390
Yes 2971
No phone service 682
Name: MultipleLines, dtype: int64
Fiber optic 3096
DSL 2421
No 1526
Name: InternetService, dtype: int64
No 3498
Yes 2019
No internet service 1526
Name: OnlineSecurity, dtype: int64
No 3088
Yes 2429
No internet service 1526
Name: OnlineBackup, dtype: int64
No 3095
Yes 2422
No internet service 1526
Name: DeviceProtection, dtype: int64
No 3473
Yes 2044
No internet service 1526
Name: TechSupport, dtype: int64
No 2810
Yes 2707
No internet service 1526
Name: StreamingTV, dtype: int64
No 2785
Yes 2732
No internet service 1526
Name: StreamingMovies, dtype: int64
Month-to-month 3875
Two year 1695
One year 1473
Name: Contract, dtype: int64
Yes 4171
No 2872
Name: PaperlessBilling, dtype: int64
Electronic check 2365
Mailed check 1612
Bank transfer (automatic) 1544
Credit card (automatic) 1522
Name: PaymentMethod, dtype: int64
20.05 61
19.85 45
19.95 44
19.90 44
20.00 43
..
114.75 1
103.60 1
113.40 1
57.65 1
113.30 1
Name: MonthlyCharges, Length: 1585, dtype: int64
11
20.2 11
19.75 9
19.65 8
19.9 8
..
5084.65 1
5088.4 1
509.3 1
5099.15 1
3808.2 1
Name: TotalCharges, Length: 6531, dtype: int64
No 5174
Yes 1869
Name: Churn, dtype: int64
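The TotalCharges counts above include a blank label that occurs 11 times, which would explain why the column is read as `object` even though it holds amounts. The notebook later drops the column; as an alternative, here is a minimal sketch of converting it to numbers. The toy frame and the choice to fill missing values with 0 are assumptions for illustration only.

```python
import pandas as pd

# Toy frame; the blank string mimics the blank TotalCharges entries (values are made up).
toy = pd.DataFrame({
    "MonthlyCharges": [52.55, 56.95, 53.85],
    "TotalCharges": [" ", "1889.5", "108.15"],
})

# errors="coerce" turns anything that cannot be parsed (here the blank) into NaN.
toy["TotalCharges"] = pd.to_numeric(toy["TotalCharges"], errors="coerce")
n_missing = int(toy["TotalCharges"].isna().sum())
print("rows with unparsable TotalCharges:", n_missing)

# One possible treatment (an assumption, not the notebook's choice): fill with 0.
toy["TotalCharges"] = toy["TotalCharges"].fillna(0.0)
print(toy.dtypes)
```

Dropping the affected rows or imputing from other columns would be equally defensible; the point is only that the column can be made numeric instead of discarded.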
replaceStruct = {
"InternetService":{"Fiber optic": 1, "DSL":2,"No":3},
"Contract":{"Month-to-month": 1, "One year": 2,"Two year":3},
"PaymentMethod":{"Electronic check": 1, "Mailed check": 2,"Bank transfer (automatic)":3,"Credit card (automatic)":4},
"Churn": {"No": 1, "Yes": 2} }
oneHotCols=["gender","Partner","Dependents","PhoneService","MultipleLines","OnlineSecurity","OnlineBackup","DeviceProtection","TechSupport","StreamingTV","StreamingMovies","PaperlessBilling"]
data=data.replace(replaceStruct)
data=pd.get_dummies(data, columns=oneHotCols)
data.head(10)
| | customerID | SeniorCitizen | tenure | InternetService | Contract | PaymentMethod | MonthlyCharges | TotalCharges | Churn | gender_Female | ... | TechSupport_No internet service | TechSupport_Yes | StreamingTV_No | StreamingTV_No internet service | StreamingTV_Yes | StreamingMovies_No | StreamingMovies_No internet service | StreamingMovies_Yes | PaperlessBilling_No | PaperlessBilling_Yes |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7590-VHVEG | 0 | 1 | 2 | 1 | 1 | 29.85 | 29.85 | 1 | 1 | ... | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
| 1 | 5575-GNVDE | 0 | 34 | 2 | 2 | 2 | 56.95 | 1889.5 | 1 | 0 | ... | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 1 | 0 |
| 2 | 3668-QPYBK | 0 | 2 | 2 | 1 | 2 | 53.85 | 108.15 | 2 | 0 | ... | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
| 3 | 7795-CFOCW | 0 | 45 | 2 | 2 | 3 | 42.30 | 1840.75 | 1 | 0 | ... | 0 | 1 | 1 | 0 | 0 | 1 | 0 | 0 | 1 | 0 |
| 4 | 9237-HQITU | 0 | 2 | 1 | 1 | 1 | 70.70 | 151.65 | 2 | 1 | ... | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
| 5 | 9305-CDSKC | 0 | 8 | 1 | 1 | 1 | 99.65 | 820.5 | 2 | 1 | ... | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 1 |
| 6 | 1452-KIOVK | 0 | 22 | 1 | 1 | 4 | 89.10 | 1949.4 | 1 | 0 | ... | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 1 |
| 7 | 6713-OKOMC | 0 | 10 | 2 | 1 | 2 | 29.75 | 301.9 | 1 | 1 | ... | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 1 | 0 |
| 8 | 7892-POOKP | 0 | 28 | 1 | 1 | 1 | 104.80 | 3046.05 | 2 | 1 | ... | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 1 |
| 9 | 6388-TABGU | 0 | 62 | 2 | 2 | 3 | 56.15 | 3487.95 | 1 | 0 | ... | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 1 | 0 |
10 rows × 40 columns
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7043 entries, 0 to 7042
Data columns (total 40 columns):
# Column Non-Null Count Dtype
0 customerID 7043 non-null category
1 SeniorCitizen 7043 non-null int64
2 tenure 7043 non-null int64
3 InternetService 7043 non-null int64
4 Contract 7043 non-null int64
5 PaymentMethod 7043 non-null int64
6 MonthlyCharges 7043 non-null float64
7 TotalCharges 7043 non-null category
8 Churn 7043 non-null int64
9 gender_Female 7043 non-null uint8
10 gender_Male 7043 non-null uint8
11 Partner_No 7043 non-null uint8
12 Partner_Yes 7043 non-null uint8
13 Dependents_No 7043 non-null uint8
14 Dependents_Yes 7043 non-null uint8
15 PhoneService_No 7043 non-null uint8
16 PhoneService_Yes 7043 non-null uint8
17 MultipleLines_No 7043 non-null uint8
18 MultipleLines_No phone service 7043 non-null uint8
19 MultipleLines_Yes 7043 non-null uint8
20 OnlineSecurity_No 7043 non-null uint8
21 OnlineSecurity_No internet service 7043 non-null uint8
22 OnlineSecurity_Yes 7043 non-null uint8
23 OnlineBackup_No 7043 non-null uint8
24 OnlineBackup_No internet service 7043 non-null uint8
25 OnlineBackup_Yes 7043 non-null uint8
26 DeviceProtection_No 7043 non-null uint8
27 DeviceProtection_No internet service 7043 non-null uint8
28 DeviceProtection_Yes 7043 non-null uint8
29 TechSupport_No 7043 non-null uint8
30 TechSupport_No internet service 7043 non-null uint8
31 TechSupport_Yes 7043 non-null uint8
32 StreamingTV_No 7043 non-null uint8
33 StreamingTV_No internet service 7043 non-null uint8
34 StreamingTV_Yes 7043 non-null uint8
35 StreamingMovies_No 7043 non-null uint8
36 StreamingMovies_No internet service 7043 non-null uint8
37 StreamingMovies_Yes 7043 non-null uint8
38 PaperlessBilling_No 7043 non-null uint8
39 PaperlessBilling_Yes 7043 non-null uint8
dtypes: category(2), float64(1), int64(6), uint8(31)
memory usage: 1.3 MB
X = data.drop(["customerID","Churn","TotalCharges"],axis=1)
y = data['Churn']
from sklearn.model_selection import train_test_split
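# NOTE: an integer test_size is an absolute row count, so this holds out only 30 of the 7043 rows;
# pass test_size=0.30 to hold out 30% of the rows instead
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=30, random_state=1)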
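The split above passes `test_size=30`, and scikit-learn reads an integer as a number of rows, not a percentage. The sketch below, on made-up arrays (`X_toy`, `y_toy` are hypothetical), contrasts the integer and fractional forms and adds `stratify`, which is an optional extra that keeps the class proportions the same in both parts.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in: 1000 rows, roughly the 73/27 class balance seen in Churn.
X_toy = np.arange(2000).reshape(1000, 2)
y_toy = np.array([1] * 730 + [2] * 270)

# Integer test_size: an absolute number of test rows.
_, X_te_int, _, _ = train_test_split(X_toy, y_toy, test_size=30, random_state=1)

# Float test_size: a fraction of the rows; stratify keeps the class ratio.
_, X_te_frac, _, y_te_frac = train_test_split(
    X_toy, y_toy, test_size=0.30, random_state=1, stratify=y_toy
)

churn_share = float((y_te_frac == 2).mean())
print(len(X_te_int), len(X_te_frac), round(churn_share, 3))
```

With only 30 test rows, each misclassified customer moves the accuracy by about 3.3 percentage points, which is worth keeping in mind when reading the scores in this notebook.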
We will build our model using DecisionTreeClassifier with the default 'gini' splitting criterion. Another option is 'entropy'.
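To make the 'gini' criterion concrete, here is a small worked computation of Gini impurity, 1 minus the sum of squared class proportions. The helper name `gini_impurity` is ours; the root-node counts are the Churn counts printed earlier (5174 No, 1869 Yes).

```python
def gini_impurity(class_counts):
    """Gini impurity of a node: 1 - sum(p_k^2) over the class proportions p_k."""
    total = sum(class_counts)
    return 1.0 - sum((count / total) ** 2 for count in class_counts)

# Impurity of the full dataset, before any split.
root = gini_impurity([5174, 1869])

# Reference points: a pure node and a perfectly mixed two-class node.
pure = gini_impurity([100, 0])
even = gini_impurity([50, 50])

print(round(root, 4), pure, even)
```

A split is attractive to the tree when the weighted impurity of the child nodes is lower than that of the parent.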
from sklearn.tree import DecisionTreeClassifier
dTree = DecisionTreeClassifier(criterion = 'gini', random_state=1)
dTree.fit(X_train, y_train)
DecisionTreeClassifier(random_state=1)
print(dTree.score(X_train, y_train))
print(dTree.score(X_test, y_test))
0.9970055611008127
0.8
The training accuracy (0.997) is far above the test accuracy (0.8), so this unpruned decision tree is overfitted.
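One way to see overfitting more systematically is to compare training and test accuracy at several tree depths. The sketch below uses synthetic data from `make_classification` rather than the churn data, so the numbers it prints say nothing about this dataset; it only illustrates the pattern.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic, noisy two-class problem (an assumption for illustration).
X_syn, y_syn = make_classification(
    n_samples=2000, n_features=20, n_informative=5, flip_y=0.1, random_state=1
)
X_tr, X_te, y_tr, y_te = train_test_split(
    X_syn, y_syn, test_size=0.30, random_state=1, stratify=y_syn
)

scores = {}
for depth in (3, 6, None):  # None lets the tree grow until the leaves are pure
    model = DecisionTreeClassifier(max_depth=depth, random_state=1).fit(X_tr, y_tr)
    scores[depth] = (model.score(X_tr, y_tr), model.score(X_te, y_te))
    print(depth, round(scores[depth][0], 3), round(scores[depth][1], 3))
```

A large gap between the two columns for a given depth is the signature of overfitting; a depth with a small gap is a more trustworthy candidate.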
Exporting the tree with export_graphviz from sklearn.tree
from sklearn.tree import export_graphviz
train_char_label = ['No', 'Yes']
# out_file accepts an open file handle; the with-block closes it even if the export fails
with open('data_tree.dot', 'w') as data_Tree_File:
    export_graphviz(dTree, out_file=data_Tree_File, feature_names=list(X_train), class_names=train_char_label)
tree.export_graphviz writes a .dot file: a text file that describes the graph structure in Graphviz's DOT language. You can plot it by
1. pasting the contents of that file at http://webgraphviz.com/ (or)
2. generating an image file using the 'dot' command (this only works if Graphviz is installed on your machine)
from os import system
from IPython.display import Image, display
# Works only if the "dot" command is available on your machine
retCode = system("dot -Tpng data_tree.dot -o data_tree.png")
if(retCode>0):
    print("system command returning error: "+str(retCode))
else:
    display(Image("data_tree.png"))
# If graphviz doesn't work, we can use plot_tree method from sklearn.tree
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
fn = list(X_train)
cn = ['No', 'Yes']
fig, axes = plt.subplots(nrows = 1,ncols = 1,figsize = (4, 4), dpi=300)
plot_tree(dTree, feature_names = fn, class_names=cn, filled = True)
fig.savefig('tree.png')
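If neither Graphviz nor a plot is convenient, `sklearn.tree.export_text` prints the split rules as plain text. This is a small sketch on synthetic data; the feature names `f0` to `f3` are placeholders, and in the notebook one would pass `list(X_train)` instead.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Tiny synthetic problem, used only to have a fitted tree to print.
X_syn, y_syn = make_classification(
    n_samples=300, n_features=4, n_informative=2, n_redundant=0, random_state=1
)
names = ["f0", "f1", "f2", "f3"]  # hypothetical feature names

small_tree = DecisionTreeClassifier(max_depth=2, random_state=1).fit(X_syn, y_syn)

# One line per node: the split condition, then the predicted class at each leaf.
rules = export_text(small_tree, feature_names=names)
print(rules)
```

The text form is practical mainly for shallow trees; an unpruned tree produces far too many lines to read.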
dTreeR = DecisionTreeClassifier(criterion = 'gini', max_depth = 3, random_state=1)
dTreeR.fit(X_train, y_train)
print(dTreeR.score(X_train, y_train))
print(dTreeR.score(X_test, y_test))
0.7905318693854271
0.9333333333333333
train_char_label = ['No', 'Yes']
# Export the depth-limited tree; the with-block closes the file for us
with open('data_treeR.dot', 'w') as data_Tree_FileR:
    export_graphviz(dTreeR, out_file=data_Tree_FileR, feature_names=list(X_train), class_names=train_char_label)
# Works only if the "dot" command is available on your machine
retCode = system("dot -Tpng data_treeR.dot -o data_treeR.png")
if(retCode>0):
    print("system command returning error: "+str(retCode))
else:
    display(Image("data_treeR.png"))
# If graphviz doesn't work, we can use plot_tree method from sklearn.tree
fn = list(X_train)
cn = ['No', 'Yes']
fig, axes = plt.subplots(nrows = 1,ncols = 1,figsize = (4, 4), dpi=300)
# Plot the regularised tree (dTreeR) and save it under its own name so tree.png is not overwritten
plot_tree(dTreeR, feature_names = fn, class_names=cn, filled = True)
fig.savefig('treeR.png')
# importance of features in the tree building ( The importance of a feature is computed as the
#(normalized) total reduction of the criterion brought by that feature. It is also known as the Gini importance )
print (pd.DataFrame(dTreeR.feature_importances_, columns = ["Imp"], index = X_train.columns))
Imp
SeniorCitizen 0.000000
tenure 0.176537
InternetService 0.180214
Contract 0.622753
PaymentMethod 0.000000
MonthlyCharges 0.020497
gender_Female 0.000000
gender_Male 0.000000
Partner_No 0.000000
Partner_Yes 0.000000
Dependents_No 0.000000
Dependents_Yes 0.000000
PhoneService_No 0.000000
PhoneService_Yes 0.000000
MultipleLines_No 0.000000
MultipleLines_No phone service 0.000000
MultipleLines_Yes 0.000000
OnlineSecurity_No 0.000000
OnlineSecurity_No internet service 0.000000
OnlineSecurity_Yes 0.000000
OnlineBackup_No 0.000000
OnlineBackup_No internet service 0.000000
OnlineBackup_Yes 0.000000
DeviceProtection_No 0.000000
DeviceProtection_No internet service 0.000000
DeviceProtection_Yes 0.000000
TechSupport_No 0.000000
TechSupport_No internet service 0.000000
TechSupport_Yes 0.000000
StreamingTV_No 0.000000
StreamingTV_No internet service 0.000000
StreamingTV_Yes 0.000000
StreamingMovies_No 0.000000
StreamingMovies_No internet service 0.000000
StreamingMovies_Yes 0.000000
PaperlessBilling_No 0.000000
PaperlessBilling_Yes 0.000000
Hence, Contract (0.62), InternetService (0.18) and tenure (0.18) drive almost all of this tree's predictions, with a small contribution from MonthlyCharges (0.02).
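The long run of zeros in the importance table is partly structural: a tree with `max_depth = 3` has at most 7 internal nodes, so it can use at most 7 features regardless of how many are supplied. A minimal sketch on synthetic data (the `feat_i` names are placeholders) shows a compact way to list only the features that were actually used:

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data with six features (an assumption for illustration).
X_syn, y_syn = make_classification(
    n_samples=500, n_features=6, n_informative=3, n_redundant=0, random_state=1
)
cols = [f"feat_{i}" for i in range(6)]

shallow = DecisionTreeClassifier(max_depth=3, random_state=1).fit(X_syn, y_syn)

# Gini importances are normalised to sum to 1; sort and keep the non-zero ones.
importances = pd.Series(shallow.feature_importances_, index=cols).sort_values(ascending=False)
used = importances[importances > 0]
print(used)
```

A zero importance therefore means "not chosen by this particular tree", which is weaker than "irrelevant to churn".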
from sklearn.metrics import confusion_matrix
import seaborn as sns
print(dTreeR.score(X_test, y_test))
y_predict = dTreeR.predict(X_test)
cm = confusion_matrix(y_test, y_predict, labels=[1, 2])
df_cm = pd.DataFrame(cm, index = [i for i in ["No","Yes"]],
columns = [i for i in ["No","Yes"]])
plt.figure(figsize = (7,5))
sns.heatmap(df_cm, annot=True ,fmt='g')
0.9333333333333333
<AxesSubplot:>
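Because churners are the minority class, accuracy alone can look good while many churners are missed. The sketch below derives recall and precision for the churn class from a confusion matrix. The label arrays are made up, but they use the notebook's encoding (1 = No, 2 = Yes), which is why `labels=[1, 2]` and `pos_label=2` are passed.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Hypothetical true and predicted labels: 1 = no churn, 2 = churn.
y_true = np.array([1, 1, 1, 1, 1, 1, 1, 2, 2, 2])
y_pred = np.array([1, 1, 1, 1, 1, 1, 2, 2, 1, 1])

# Rows are true classes, columns are predicted classes, both in the order [1, 2].
cm_toy = confusion_matrix(y_true, y_pred, labels=[1, 2])

accuracy = float((y_true == y_pred).mean())
recall_churn = recall_score(y_true, y_pred, pos_label=2)        # churners caught / all churners
precision_churn = precision_score(y_true, y_pred, pos_label=2)  # true churners / predicted churners

print(cm_toy)
print(accuracy, recall_churn, precision_churn)
```

In this toy case the accuracy is 0.7 although only one of the three churners is identified, which is the situation a retention team would care about most.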
from sklearn.ensemble import BaggingClassifier
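# The base estimator is passed positionally because its keyword is base_estimator in older scikit-learn and estimator in newer releases
bgcl = BaggingClassifier(dTree, n_estimators=50,random_state=1)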
#bgcl = BaggingClassifier(n_estimators=50,random_state=1)
bgcl = bgcl.fit(X_train, y_train)
from sklearn.metrics import confusion_matrix
y_predict = bgcl.predict(X_test)
print(bgcl.score(X_test , y_test))
cm=confusion_matrix(y_test, y_predict,labels=[1, 2]) # Churn was encoded as No=1, Yes=2 above
df_cm = pd.DataFrame(cm, index = [i for i in ["No","Yes"]],
columns = [i for i in ["No","Yes"]])
plt.figure(figsize = (7,5))
sns.heatmap(df_cm, annot=True ,fmt='g')
0.8333333333333334
<AxesSubplot:>
from sklearn.ensemble import AdaBoostClassifier
abcl = AdaBoostClassifier(n_estimators=10, random_state=1)
#abcl = AdaBoostClassifier( n_estimators=50,random_state=1)
abcl = abcl.fit(X_train, y_train)
y_predict = abcl.predict(X_test)
print(abcl.score(X_test , y_test))
cm=confusion_matrix(y_test, y_predict,labels=[1, 2]) # Churn was encoded as No=1, Yes=2 above
df_cm = pd.DataFrame(cm, index = [i for i in ["No","Yes"]],
columns = [i for i in ["No","Yes"]])
plt.figure(figsize = (7,5))
sns.heatmap(df_cm, annot=True ,fmt='g')
0.8666666666666667
<AxesSubplot:>
from sklearn.ensemble import GradientBoostingClassifier
gbcl = GradientBoostingClassifier(n_estimators = 50,random_state=1)
gbcl = gbcl.fit(X_train, y_train)
y_predict = gbcl.predict(X_test)
print(gbcl.score(X_test, y_test))
cm=confusion_matrix(y_test, y_predict,labels=[1, 2]) # Churn was encoded as No=1, Yes=2 above
df_cm = pd.DataFrame(cm, index = [i for i in ["No","Yes"]],
columns = [i for i in ["No","Yes"]])
plt.figure(figsize = (7,5))
sns.heatmap(df_cm, annot=True ,fmt='g')
0.9
<AxesSubplot:>
from sklearn.ensemble import RandomForestClassifier
rfcl = RandomForestClassifier(n_estimators = 50, random_state=1,max_features=12)
rfcl = rfcl.fit(X_train, y_train)
y_predict = rfcl.predict(X_test)
print(rfcl.score(X_test, y_test))
cm=confusion_matrix(y_test, y_predict,labels=[1, 2]) # Churn was encoded as No=1, Yes=2 above
df_cm = pd.DataFrame(cm, index = [i for i in ["No","Yes"]],
columns = [i for i in ["No","Yes"]])
plt.figure(figsize = (7,5))
sns.heatmap(df_cm, annot=True ,fmt='g')
0.9
<AxesSubplot:>
Accuracy summary, taken from the outputs above (test scores are on the 30-row hold-out set):
Decision tree, unpruned: train 0.9970, test 0.8000
Regularised decision tree (max_depth = 3): train 0.7905, test 0.9333
Decision tree with bagging: test 0.8333
Adaptive boosting: test 0.8667
Gradient boosting: test 0.9000
Random forest classifier: test 0.9000
Hence, on this test set the regularised decision tree scores highest (0.9333), followed by gradient boosting and random forest (0.9 each), adaptive boosting (0.8667) and bagging (0.8333). With only 30 test rows, these gaps amount to one or two customers, so the ranking should be treated as tentative.
We also observed that bagging classifiers in general benefit from complex individual models, while boosting classifiers in general benefit from simple ones.
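The bagging versus boosting remark can be checked empirically by pairing each ensemble with a simple and a complex base tree and scoring them with cross-validation. The sketch below does this on synthetic data, so it illustrates the experiment rather than confirming the claim for the churn data; the candidate names and settings are our own choices.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic two-class data (an assumption for illustration).
X_syn, y_syn = make_classification(
    n_samples=600, n_features=10, n_informative=4, flip_y=0.05, random_state=1
)

stump = DecisionTreeClassifier(max_depth=1, random_state=1)  # simple base model
deep = DecisionTreeClassifier(random_state=1)                # complex base model

# Base estimators are passed positionally so the code runs across scikit-learn versions.
candidates = {
    "bagging + deep trees": BaggingClassifier(deep, n_estimators=25, random_state=1),
    "bagging + stumps": BaggingClassifier(stump, n_estimators=25, random_state=1),
    "boosting + stumps": AdaBoostClassifier(stump, n_estimators=25, random_state=1),
    "boosting + deep trees": AdaBoostClassifier(deep, n_estimators=25, random_state=1),
}

# Mean 5-fold cross-validated accuracy for each pairing.
cv_scores = {
    name: float(cross_val_score(model, X_syn, y_syn, cv=5).mean())
    for name, model in candidates.items()
}
for name, score in cv_scores.items():
    print(f"{name}: {score:.3f}")
```

Running the same comparison on `X` and `y` from this notebook, with cross-validation in place of the 30-row test set, would give a firmer basis for choosing between the models.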